[ad8447]: / (1) PyTorch_HistoNet / functions / visualize_model.py

Download this file

38 lines (32 with data), 1.2 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import util.imshow as imshow
import numpy as np
def _draw_predictions(model, dataloaders, cuda, num_images):
    """Draw up to ``num_images`` (image, label, prediction) panels on the current figure.

    Panels are laid out in a grid with 2 columns and ``ceil(num_images / 2)`` rows,
    so an odd ``num_images`` no longer overflows the grid.
    """
    if num_images <= 0:
        # Nothing to draw; the original code would have crashed in plt.subplot.
        return
    n_rows = (num_images + 1) // 2  # ceil(num_images / 2)
    images_so_far = 0
    with torch.no_grad():
        # Each batch is unpacked as (inputs, dummy target, filename, labels);
        # only inputs and labels are used here.
        for inputs, _dummy_target, _filename, labels in dataloaders:
            if cuda:
                inputs = inputs.to('cuda')
            outputs = model(inputs)
            # Threshold the sigmoid of each output at 0.5 -> 0/1 per output.
            # NOTE(review): presumably outputs are multi-label logits — confirm.
            preds = (torch.sigmoid(outputs) > 0.5).int().cpu()
            # Move the batch to CPU once, not once per displayed image.
            inputs_cpu = inputs.cpu()
            for j in range(inputs_cpu.size(0)):
                images_so_far += 1
                ax = plt.subplot(n_rows, 2, images_so_far)
                ax.axis('off')
                ax.set_title('real: {}; predicted: {}'.format(
                    labels[j].numpy(), preds[j].numpy()))
                # NOTE(review): the file does `import util.imshow as imshow`;
                # this call only works if that name resolves to a function
                # rather than the submodule itself — confirm.
                imshow(inputs_cpu[j])
                if images_so_far == num_images:
                    return


def visualize_model(model, dataloaders, cuda, class_names, num_images=6):
    """Plot model predictions next to ground-truth labels for a few images.

    The model is switched to eval mode for the duration of the call. Its
    previous training/eval mode is restored afterwards, even if drawing raises.
    Panels are drawn into matplotlib figure 1 and shown with ``plt.show()``.

    Args:
        model: torch module; its ``training`` flag is saved and restored.
        dataloaders: iterable yielding 4-tuples
            ``(inputs, dummy_target, filename, labels)``.
        cuda: if truthy, ``inputs`` are moved to ``'cuda'`` before the forward
            pass (the model is assumed to already live there).
        class_names: currently unused; kept for interface compatibility.
        num_images: maximum number of images to display (default 6).
            Values <= 0 draw nothing.
    """
    was_training = model.training
    model.eval()
    plt.figure(1)
    try:
        _draw_predictions(model, dataloaders, cuda, num_images)
    finally:
        # Always hand the model back in the mode we received it in.
        model.train(mode=was_training)
    plt.show()